{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/lamb_optimizer.py:34: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from albert import modeling\n",
    "from albert import optimization\n",
    "from albert import tokenization\n",
    "import tensorflow as tf\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/tokenization.py:240: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n",
      "\n",
      "INFO:tensorflow:loading sentence piece model\n"
     ]
    }
   ],
   "source": [
    "tokenizer = tokenization.FullTokenizer(\n",
    "      vocab_file='albert-tiny-2020-04-17/sp10m.cased.v10.vocab', do_lower_case=False,\n",
    "      spm_model_file='albert-tiny-2020-04-17/sp10m.cased.v10.model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:116: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<albert.modeling.AlbertConfig at 0x7faabec00080>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "albert_config = modeling.AlbertConfig.from_json_file('albert-tiny-2020-04-17/config.json')\n",
    "albert_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rules import normalized_chars\n",
    "import random\n",
    "from unidecode import unidecode\n",
    "import re\n",
    "\n",
    "laughing = {\n",
    "    'huhu',\n",
    "    'haha',\n",
    "    'gagaga',\n",
    "    'hihi',\n",
    "    'wkawka',\n",
    "    'wkwk',\n",
    "    'kiki',\n",
    "    'keke',\n",
    "    'huehue',\n",
    "    'hshs',\n",
    "    'hoho',\n",
    "    'hewhew',\n",
    "    'uwu',\n",
    "    'sksk',\n",
    "    'ksks',\n",
    "    'gituu',\n",
    "    'gitu',\n",
    "    'mmeeooww',\n",
    "    'meow',\n",
    "    'alhamdulillah',\n",
    "    'muah',\n",
    "    'mmuahh',\n",
    "    'hehe',\n",
    "    'salamramadhan',\n",
    "    'happywomensday',\n",
    "    'jahagaha',\n",
    "    'ahakss',\n",
    "    'ahksk'\n",
    "}\n",
    "\n",
    "def make_cleaning(s, c_dict):\n",
    "    s = s.translate(c_dict)\n",
    "    return s\n",
    "\n",
    "def cleaning(string):\n",
    "    \"\"\"\n",
    "    use by any transformer model before tokenization\n",
    "    \"\"\"\n",
    "    string = unidecode(string)\n",
    "    \n",
    "    string = ' '.join(\n",
    "        [make_cleaning(w, normalized_chars) for w in string.split()]\n",
    "    )\n",
    "    string = re.sub('\\(dot\\)', '.', string)\n",
    "    string = (\n",
    "        re.sub(re.findall(r'\\<a(.*?)\\>', string)[0], '', string)\n",
    "        if (len(re.findall(r'\\<a (.*?)\\>', string)) > 0)\n",
    "        and ('href' in re.findall(r'\\<a (.*?)\\>', string)[0])\n",
    "        else string\n",
    "    )\n",
    "    string = re.sub(\n",
    "        r'\\w+:\\/{2}[\\d\\w-]+(\\.[\\d\\w-]+)*(?:(?:\\/[^\\s/]*))*', ' ', string\n",
    "    )\n",
    "    \n",
    "    chars = '.,/'\n",
    "    for c in chars:\n",
    "        string = string.replace(c, f' {c} ')\n",
    "        \n",
    "    string = re.sub(r'[ ]+', ' ', string).strip().split()\n",
    "    string = [w for w in string if w[0] != '@']\n",
    "    x = []\n",
    "    for word in string:\n",
    "        word = word.lower()\n",
    "        if any([laugh in word for laugh in laughing]):\n",
    "            if random.random() >= 0.5:\n",
    "                x.append(word)\n",
    "        else:\n",
    "            x.append(word)\n",
    "    string = [w.title() if w[0].isupper() else w for w in x]\n",
    "    return ' '.join(string)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.preprocessing import LabelEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['severe toxic',\n",
       " 'obscene',\n",
       " 'identity attack',\n",
       " 'insult',\n",
       " 'threat',\n",
       " 'asian',\n",
       " 'atheist',\n",
       " 'bisexual',\n",
       " 'black',\n",
       " 'buddhist',\n",
       " 'christian',\n",
       " 'female',\n",
       " 'heterosexual',\n",
       " 'indian',\n",
       " 'homosexual, gay or lesbian',\n",
       " 'intellectual or learning disability',\n",
       " 'jewish',\n",
       " 'latino',\n",
       " 'male',\n",
       " 'muslim',\n",
       " 'other disability',\n",
       " 'other gender',\n",
       " 'other race or ethnicity',\n",
       " 'other religion',\n",
       " 'other sexual orientation',\n",
       " 'physical disability',\n",
       " 'psychiatric or mental illness',\n",
       " 'transgender',\n",
       " 'white',\n",
       " 'malay',\n",
       " 'chinese']"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels = \"\"\"\n",
    "1. severe toxic\n",
    "2. obscene\n",
    "3. identity attack\n",
    "4. insult\n",
    "5. threat\n",
    "6. asian\n",
    "7. atheist\n",
    "8. bisexual\n",
    "9. black\n",
    "10. buddhist\n",
    "11. christian\n",
    "12. female\n",
    "13. heterosexual\n",
    "14. indian\n",
    "15. homosexual, gay or lesbian\n",
    "16. intellectual or learning disability\n",
    "17. jewish\n",
    "18. latino\n",
    "19. male\n",
    "20. muslim\n",
    "21. other disability\n",
    "22. other gender\n",
    "23. other race or ethnicity\n",
    "24. other religion\n",
    "25. other sexual orientation\n",
    "26. physical disability\n",
    "27. psychiatric or mental illness\n",
    "28. transgender\n",
    "29. white\n",
    "30. malay\n",
    "31. chinese\n",
    "\"\"\"\n",
    "labels = [l.split('. ')[1].strip() for l in labels.split('\\n') if len(l)]\n",
    "labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['../toxicity/translated-1750000.json',\n",
       " '../toxicity/translated-1450000.json',\n",
       " '../toxicity/translated-700000.json',\n",
       " '../toxicity/translated-350000.json',\n",
       " '../toxicity/translated-600000.json',\n",
       " '../toxicity/translated-900000.json',\n",
       " '../toxicity/translated-1000000.json',\n",
       " '../toxicity/translated-1100000.json',\n",
       " '../toxicity/translated-550000.json',\n",
       " '../toxicity/translated-150000.json',\n",
       " '../toxicity/translated-500000.json',\n",
       " '../toxicity/translated-1500000.json',\n",
       " '../toxicity/translated-1150000.json',\n",
       " '../toxicity/translated-750000.json',\n",
       " '../toxicity/translated-850000.json',\n",
       " '../toxicity/translated-1650000.json',\n",
       " '../toxicity/translated-300000.json',\n",
       " '../toxicity/translated-650000.json',\n",
       " '../toxicity/translated-950000.json',\n",
       " '../toxicity/translated-250000.json',\n",
       " '../toxicity/translated-1600000.json',\n",
       " '../toxicity/translated-0.json',\n",
       " '../toxicity/translated-1550000.json',\n",
       " '../toxicity/translated-1800000.json',\n",
       " '../toxicity/translated-450000.json',\n",
       " '../toxicity/translated-50000.json',\n",
       " '../toxicity/translated-1050000.json',\n",
       " '../toxicity/translated-1200000.json',\n",
       " '../toxicity/translated-1700000.json',\n",
       " '../toxicity/translated-400000.json']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import glob\n",
    "\n",
    "files = glob.glob('../toxicity/translated*')\n",
    "files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "../toxicity/translated-1750000.json\n",
      "../toxicity/translated-1450000.json\n",
      "../toxicity/translated-700000.json\n",
      "../toxicity/translated-350000.json\n",
      "../toxicity/translated-600000.json\n",
      "../toxicity/translated-900000.json\n",
      "../toxicity/translated-1000000.json\n",
      "../toxicity/translated-1100000.json\n",
      "../toxicity/translated-550000.json\n",
      "../toxicity/translated-150000.json\n",
      "../toxicity/translated-500000.json\n",
      "../toxicity/translated-1500000.json\n",
      "../toxicity/translated-1150000.json\n",
      "../toxicity/translated-750000.json\n",
      "../toxicity/translated-850000.json\n",
      "../toxicity/translated-1650000.json\n",
      "../toxicity/translated-300000.json\n",
      "../toxicity/translated-650000.json\n",
      "../toxicity/translated-950000.json\n",
      "../toxicity/translated-250000.json\n",
      "../toxicity/translated-1600000.json\n",
      "../toxicity/translated-0.json\n",
      "../toxicity/translated-1550000.json\n",
      "../toxicity/translated-1800000.json\n",
      "../toxicity/translated-450000.json\n",
      "../toxicity/translated-50000.json\n",
      "../toxicity/translated-1050000.json\n",
      "../toxicity/translated-1200000.json\n",
      "../toxicity/translated-1700000.json\n",
      "../toxicity/translated-400000.json\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "1401054"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "X, Y = [], []\n",
    "\n",
    "for file in files:\n",
    "    print(file)\n",
    "    with open(file) as fopen:\n",
    "        f = json.load(fopen)\n",
    "        for row in f:\n",
    "            if len(row[1]) == 29:\n",
    "                X.append(row[0])\n",
    "                Y.append(row[1] + [0, 0])\n",
    "        \n",
    "    \n",
    "len(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "rejected_labels = ['black', 'white', 'jewish', 'latino']\n",
    "[labels.index(l) for l in rejected_labels]\n",
    "labels = [l for l in labels if l not in rejected_labels]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "ydf = pd.DataFrame(np.array(Y))\n",
    "ydf = ydf.loc[(ydf[8] == 0) & (ydf[28] == 0) & (ydf[16] == 0) & (ydf[17] == 0)]\n",
    "ydf = ydf.drop([8, 28, 16, 17], axis = 1)\n",
    "ix = ydf.index.tolist()\n",
    "Y = ydf.values.tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1361040, 1361040)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = [X[i] for i in ix]\n",
    "len(X), len(Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "mapping = {'severe_toxic': 'severe toxic', 'identity_hate': 'identity attack',\n",
    "          'toxic': 'severe toxic', 'melayu': 'malay', 'cina': 'chinese', 'india': 'indian'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_onehot(tags, depth = len(labels)):\n",
    "    onehot = [0] * depth\n",
    "    for tag in tags:\n",
    "        onehot[labels.index(tag)] = 1\n",
    "    return onehot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "melayu 84851\n",
      "cina 43956\n",
      "india 20208\n"
     ]
    }
   ],
   "source": [
    "with open('../toxicity/kaum.json') as fopen:\n",
    "    kaum = json.load(fopen)\n",
    "    \n",
    "for k, v in kaum.items():\n",
    "    print(k, len(v))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../toxicity/weak-learning-toxicity.json') as fopen:\n",
    "    scores = json.load(fopen)\n",
    "    \n",
    "for k, v in scores.items():\n",
    "    for no in range(len(v)):\n",
    "        tags = []\n",
    "        for l, v_ in v[no].items():\n",
    "            if round(v_) == 1:\n",
    "                tags.append(mapping.get(l, l))\n",
    "        tags.append(mapping[k])\n",
    "        Y.append(generate_onehot(tags))\n",
    "        X.append(kaum[k][no])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1510055/1510055 [04:46<00:00, 5272.78it/s] \n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "for i in tqdm(range(len(X))):\n",
    "    X[i] = cleaning(X[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1510055/1510055 [00:01<00:00, 1288321.04it/s]\n"
     ]
    }
   ],
   "source": [
    "actual_t, actual_l = [], []\n",
    "\n",
    "for i in tqdm(range(len(X))):\n",
    "    if len(X[i]) > 2:\n",
    "        actual_t.append(X[i])\n",
    "        actual_l.append(Y[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1502317/1502317 [09:50<00:00, 2544.56it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "MAX_SEQ_LENGTH = 150\n",
    "input_ids, input_masks, segment_ids = [], [], []\n",
    "\n",
    "for text in tqdm(actual_t):\n",
    "    tokens_a = tokenizer.tokenize(text)\n",
    "    if len(tokens_a) > MAX_SEQ_LENGTH - 2:\n",
    "        tokens_a = tokens_a[:(MAX_SEQ_LENGTH - 2)]\n",
    "    tokens = [\"[CLS]\"] + tokens_a + [\"[SEP]\"]\n",
    "    segment_id = [0] * len(tokens)\n",
    "    input_id = tokenizer.convert_tokens_to_ids(tokens)\n",
    "    input_mask = [1] * len(input_id)\n",
    "    padding = [0] * (MAX_SEQ_LENGTH - len(input_id))\n",
    "    input_id += padding\n",
    "    input_mask += padding\n",
    "    segment_id += padding\n",
    "    \n",
    "    input_ids.append(input_id)\n",
    "    input_masks.append(input_mask)\n",
    "    segment_ids.append(segment_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "epoch = 10\n",
    "batch_size = 60\n",
    "warmup_proportion = 0.1\n",
    "num_train_steps = int(len(actual_t) / batch_size * epoch)\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_initializer(initializer_range=0.02):\n",
    "    return tf.truncated_normal_initializer(stddev=initializer_range)\n",
    "\n",
    "class Model:\n",
    "    def __init__(\n",
    "        self,\n",
    "        dimension_output,\n",
    "        learning_rate = 2e-5,\n",
    "        training = True\n",
    "    ):\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.MASK = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.float32, [None, dimension_output])\n",
    "        \n",
    "        model = modeling.AlbertModel(\n",
    "            config=albert_config,\n",
    "            is_training=training,\n",
    "            input_ids=self.X,\n",
    "            input_mask=self.MASK,\n",
    "            use_one_hot_embeddings=False)\n",
    "        \n",
    "        output_layer = model.get_sequence_output()\n",
    "        output_layer = tf.layers.dense(\n",
    "            output_layer,\n",
    "            albert_config.hidden_size,\n",
    "            activation=tf.tanh,\n",
    "            kernel_initializer=create_initializer())\n",
    "        self.logits_seq = tf.layers.dense(output_layer, dimension_output,\n",
    "                                         kernel_initializer=create_initializer())\n",
    "        self.logits_seq = tf.identity(self.logits_seq, name = 'logits_seq')\n",
    "        self.logits = self.logits_seq[:, 0]\n",
    "        self.logits = tf.identity(self.logits, name = 'logits')\n",
    "        \n",
    "        self.cost = tf.reduce_mean(\n",
    "            tf.nn.sigmoid_cross_entropy_with_logits(\n",
    "                logits = self.logits, labels = self.Y\n",
    "            )\n",
    "        )\n",
    "        \n",
    "        self.optimizer = optimization.create_optimizer(self.cost, learning_rate, \n",
    "                                                       num_train_steps, num_warmup_steps, False)\n",
    "        \n",
    "        correct_prediction = tf.equal(tf.round(tf.nn.sigmoid(self.logits)), tf.round(self.Y))\n",
    "        all_labels_true = tf.reduce_min(tf.cast(correct_prediction, tf.float32), 1)\n",
    "        self.accuracy = tf.reduce_mean(all_labels_true)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "BERT_INIT_CHKPNT = 'albert-tiny-2020-04-17/model.ckpt-1000000'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:194: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:507: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:588: The name tf.assert_less_equal is deprecated. Please use tf.compat.v1.assert_less_equal instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:1025: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:253: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.Dense instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/layers/core.py:187: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/nn_impl.py:183: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/optimization.py:36: The name tf.train.get_or_create_global_step is deprecated. Please use tf.compat.v1.train.get_or_create_global_step instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/optimization.py:41: The name tf.train.polynomial_decay is deprecated. Please use tf.compat.v1.train.polynomial_decay instead.\n",
      "\n",
      "INFO:tensorflow:++++++ warmup starts at step 0, for 25038 steps ++++++\n",
      "INFO:tensorflow:using adamw\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/optimization.py:101: The name tf.trainable_variables is deprecated. Please use tf.compat.v1.trainable_variables instead.\n",
      "\n",
      "INFO:tensorflow:Restoring parameters from albert-tiny-2020-04-17/model.ckpt-1000000\n"
     ]
    }
   ],
   "source": [
    "dimension_output = len(labels)\n",
    "learning_rate = 2e-5\n",
    "\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(\n",
    "    dimension_output,\n",
    "    learning_rate\n",
    ")\n",
    "\n",
    "sess.run(tf.global_variables_initializer())\n",
    "var_lists = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope = 'bert')\n",
    "saver = tf.train.Saver(var_list = var_lists)\n",
    "saver.restore(sess, BERT_INIT_CHKPNT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "train_input_ids, test_input_ids, train_Y, test_Y, train_mask, test_mask = train_test_split(\n",
    "    input_ids, actual_l, input_masks, test_size = 0.2\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 20031/20031 [27:58<00:00, 11.93it/s, accuracy=0.962, cost=0.0204]\n",
      "test minibatch loop: 100%|██████████| 5008/5008 [02:41<00:00, 30.95it/s, accuracy=0.955, cost=0.0238]\n",
      "train minibatch loop:   0%|          | 2/20031 [00:00<27:59, 11.93it/s, accuracy=0.9, cost=0.024]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.000000, current acc: 0.907103\n",
      "time taken: 1840.1854791641235\n",
      "epoch: 0, training loss: 0.094182, training acc: 0.833978, valid loss: 0.023948, valid acc: 0.907103\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 20031/20031 [27:59<00:00, 11.93it/s, accuracy=0.943, cost=0.0199] \n",
      "test minibatch loop: 100%|██████████| 5008/5008 [02:43<00:00, 30.69it/s, accuracy=0.977, cost=0.0223] \n",
      "train minibatch loop:   0%|          | 2/20031 [00:00<27:49, 11.99it/s, accuracy=0.9, cost=0.0203]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, pass acc: 0.907103, current acc: 0.916076\n",
      "time taken: 1842.6004388332367\n",
      "epoch: 1, training loss: 0.022479, training acc: 0.912580, valid loss: 0.021688, valid acc: 0.916076\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 20031/20031 [28:00<00:00, 11.92it/s, accuracy=0.962, cost=0.0198] \n",
      "test minibatch loop: 100%|██████████| 5008/5008 [02:42<00:00, 30.82it/s, accuracy=0.977, cost=0.0223] \n",
      "train minibatch loop:   0%|          | 2/20031 [00:00<28:00, 11.92it/s, accuracy=0.933, cost=0.0196]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 2, pass acc: 0.916076, current acc: 0.918096\n",
      "time taken: 1843.135541677475\n",
      "epoch: 2, training loss: 0.021161, training acc: 0.918443, valid loss: 0.021274, valid acc: 0.918096\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 20031/20031 [27:59<00:00, 11.92it/s, accuracy=0.943, cost=0.0197] \n",
      "test minibatch loop: 100%|██████████| 5008/5008 [02:43<00:00, 30.69it/s, accuracy=0.977, cost=0.0226] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 1843.0814881324768\n",
      "epoch: 3, training loss: 0.020534, training acc: 0.922090, valid loss: 0.021149, valid acc: 0.917814\n",
      "\n",
      "break epoch:4\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import time\n",
    "\n",
    "EARLY_STOPPING, CURRENT_CHECKPOINT, CURRENT_ACC, EPOCH = 1, 0, 0, 0\n",
    "\n",
    "while True:\n",
    "    lasttime = time.time()\n",
    "    if CURRENT_CHECKPOINT == EARLY_STOPPING:\n",
    "        print('break epoch:%d\\n' % (EPOCH))\n",
    "        break\n",
    "\n",
    "    train_acc, train_loss, test_acc, test_loss = [], [], [], []\n",
    "    pbar = tqdm(\n",
    "        range(0, len(train_input_ids), batch_size), desc = 'train minibatch loop'\n",
    "    )\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(train_input_ids))\n",
    "        batch_x = train_input_ids[i: index]\n",
    "        batch_y = train_Y[i: index]\n",
    "        batch_mask = train_mask[i: index]\n",
    "        acc, cost, _ = sess.run(\n",
    "            [model.accuracy, model.cost, model.optimizer],\n",
    "            feed_dict = {\n",
    "                model.Y: batch_y,\n",
    "                model.X: batch_x,\n",
    "                model.MASK: batch_mask\n",
    "            },\n",
    "        )\n",
    "        train_loss.append(cost)\n",
    "        train_acc.append(acc)\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "        \n",
    "    pbar = tqdm(range(0, len(test_input_ids), batch_size), desc = 'test minibatch loop')\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(test_input_ids))\n",
    "        batch_x = test_input_ids[i: index]\n",
    "        batch_y = test_Y[i: index]\n",
    "        batch_mask = test_mask[i: index]\n",
    "        acc, cost = sess.run(\n",
    "            [model.accuracy, model.cost],\n",
    "            feed_dict = {\n",
    "                model.Y: batch_y,\n",
    "                model.X: batch_x,\n",
    "                model.MASK: batch_mask\n",
    "            },\n",
    "        )\n",
    "        test_loss.append(cost)\n",
    "        test_acc.append(acc)\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "        \n",
    "    train_loss = np.mean(train_loss)\n",
    "    train_acc = np.mean(train_acc)\n",
    "    test_loss = np.mean(test_loss)\n",
    "    test_acc = np.mean(test_acc)\n",
    "        \n",
    "    if test_acc > CURRENT_ACC:\n",
    "        print(\n",
    "            'epoch: %d, pass acc: %f, current acc: %f'\n",
    "            % (EPOCH, CURRENT_ACC, test_acc)\n",
    "        )\n",
    "        CURRENT_ACC = test_acc\n",
    "        CURRENT_CHECKPOINT = 0\n",
    "    else:\n",
    "        CURRENT_CHECKPOINT += 1\n",
    "        \n",
    "    print('time taken:', time.time() - lasttime)\n",
    "    print(\n",
    "        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'\n",
    "        % (EPOCH, train_loss, train_acc, test_loss, test_acc)\n",
    "    )\n",
    "    EPOCH += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'albert-tiny-toxic/model.ckpt'"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "saver = tf.train.Saver(tf.trainable_variables())\n",
    "saver.save(sess, 'albert-tiny-toxic/model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/client/session.py:1750: UserWarning: An interactive session is already active. This can cause out-of-memory errors in some cases. You must explicitly call `InteractiveSession.close()` to release resources held by the other session(s).\n",
      "  warnings.warn('An interactive session is already active. This can '\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:++++++ warmup starts at step 0, for 25038 steps ++++++\n",
      "INFO:tensorflow:using adamw\n",
      "INFO:tensorflow:Restoring parameters from albert-tiny-toxic/model.ckpt\n"
     ]
    }
   ],
   "source": [
    "dimension_output = len(labels)\n",
    "learning_rate = 2e-5\n",
    "\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(\n",
    "    dimension_output,\n",
    "    learning_rate,\n",
    "    training=False\n",
    ")\n",
    "\n",
    "sess.run(tf.global_variables_initializer())\n",
    "saver = tf.train.Saver(tf.trainable_variables())\n",
    "saver.restore(sess, 'albert-tiny-toxic/model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "validation minibatch loop: 100%|██████████| 5008/5008 [16:36<00:00,  5.03it/s]\n"
     ]
    }
   ],
   "source": [
    "stack = []\n",
    "\n",
    "pbar = tqdm(\n",
    "    range(0, len(test_input_ids), batch_size), desc = 'validation minibatch loop'\n",
    ")\n",
    "for i in pbar:\n",
    "    index = min(i + batch_size, len(test_input_ids))\n",
    "    batch_x = test_input_ids[i: index]\n",
    "    batch_mask = test_mask[i: index]\n",
    "    stack.append(sess.run(tf.nn.sigmoid(model.logits),\n",
    "            feed_dict = {\n",
    "                model.X: batch_x,\n",
    "                model.MASK: batch_mask\n",
    "            },\n",
    "    ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                     precision    recall  f1-score   support\n",
      "\n",
      "                       severe toxic    0.78533   0.72620   0.75460      9788\n",
      "                            obscene    0.67641   0.33796   0.45072      2808\n",
      "                    identity attack    0.66042   0.22988   0.34104      1379\n",
      "                             insult    0.74085   0.47457   0.57854     12662\n",
      "                             threat    0.52941   0.02153   0.04138       418\n",
      "                              asian    0.65027   0.29975   0.41034       397\n",
      "                            atheist    0.85882   0.82022   0.83908       178\n",
      "                           bisexual    1.00000   0.03125   0.06061        32\n",
      "                           buddhist    0.73333   0.26190   0.38596        42\n",
      "                          christian    0.87017   0.84438   0.85708      4723\n",
      "                             female    0.85865   0.92302   0.88967      6963\n",
      "                       heterosexual    0.76147   0.70339   0.73128       118\n",
      "                             indian    0.93209   0.90115   0.91636      4097\n",
      "         homosexual, gay or lesbian    0.89625   0.89690   0.89658      1387\n",
      "intellectual or learning disability    0.00000   0.00000   0.00000         7\n",
      "                               male    0.68679   0.62619   0.65509      4941\n",
      "                             muslim    0.86187   0.87102   0.86642      2543\n",
      "                   other disability    0.00000   0.00000   0.00000         0\n",
      "                       other gender    0.00000   0.00000   0.00000         0\n",
      "            other race or ethnicity    0.00000   0.00000   0.00000         8\n",
      "                     other religion    0.00000   0.00000   0.00000         8\n",
      "           other sexual orientation    0.00000   0.00000   0.00000         0\n",
      "                physical disability    0.00000   0.00000   0.00000         1\n",
      "      psychiatric or mental illness    0.74208   0.57243   0.64631       573\n",
      "                        transgender    0.79327   0.76037   0.77647       217\n",
      "                              malay    0.99392   0.93774   0.96501     16896\n",
      "                            chinese    0.89948   0.96317   0.93024      8770\n",
      "\n",
      "                          micro avg    0.85977   0.76240   0.80816     78956\n",
      "                          macro avg    0.59003   0.45196   0.48121     78956\n",
      "                       weighted avg    0.84448   0.76240   0.79239     78956\n",
      "                        samples avg    0.15465   0.15102   0.15056     78956\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.6/site-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/husein/.local/lib/python3.6/site-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/husein/.local/lib/python3.6/site-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in samples with no predicted labels. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/husein/.local/lib/python3.6/site-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in samples with no true labels. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
     ]
    }
   ],
   "source": [
    "from sklearn import metrics\n",
    "\n",
    "print(metrics.classification_report(np.around(np.array(test_Y)),\n",
    "                                    np.around(np.concatenate(stack,axis=0)),\n",
    "                                    target_names=labels,\n",
    "                                    digits=5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Placeholder',\n",
       " 'Placeholder_1',\n",
       " 'Placeholder_2',\n",
       " 'bert/embeddings/word_embeddings',\n",
       " 'bert/embeddings/token_type_embeddings',\n",
       " 'bert/embeddings/position_embeddings',\n",
       " 'bert/embeddings/LayerNorm/gamma',\n",
       " 'bert/encoder/embedding_hidden_mapping_in/kernel',\n",
       " 'bert/encoder/embedding_hidden_mapping_in/bias',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/attention_1/self/query/kernel',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/attention_1/self/query/bias',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/attention_1/self/key/kernel',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/attention_1/self/key/bias',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/attention_1/self/value/kernel',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/attention_1/self/value/bias',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/attention_1/output/dense/kernel',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/attention_1/output/dense/bias',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/LayerNorm/gamma',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/dense/kernel',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/dense/bias',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/output/dense/kernel',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/output/dense/bias',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/LayerNorm_1/gamma',\n",
       " 'bert/pooler/dense/kernel',\n",
       " 'bert/pooler/dense/bias',\n",
       " 'dense/kernel',\n",
       " 'dense/bias',\n",
       " 'dense_1/kernel',\n",
       " 'dense_1/bias',\n",
       " 'logits_seq',\n",
       " 'logits']"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "strings = ','.join(\n",
    "    [\n",
    "        n.name\n",
    "        for n in tf.get_default_graph().as_graph_def().node\n",
    "        if ('Variable' in n.op\n",
    "        or 'Placeholder' in n.name\n",
    "        or 'logits' in n.name\n",
    "        or 'alphas' in n.name\n",
    "        or 'self/Softmax' in n.name)\n",
    "        and 'adam' not in n.name\n",
    "        and 'beta' not in n.name\n",
    "        and 'global_step' not in n.name\n",
    "    ]\n",
    ")\n",
    "strings.split(',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "def freeze_graph(model_dir, output_node_names):\n",
    "\n",
    "    if not tf.gfile.Exists(model_dir):\n",
    "        raise AssertionError(\n",
    "            \"Export directory doesn't exists. Please specify an export \"\n",
    "            'directory: %s' % model_dir\n",
    "        )\n",
    "\n",
    "    checkpoint = tf.train.get_checkpoint_state(model_dir)\n",
    "    input_checkpoint = checkpoint.model_checkpoint_path\n",
    "\n",
    "    absolute_model_dir = '/'.join(input_checkpoint.split('/')[:-1])\n",
    "    output_graph = absolute_model_dir + '/frozen_model.pb'\n",
    "    clear_devices = True\n",
    "    with tf.Session(graph = tf.Graph()) as sess:\n",
    "        saver = tf.train.import_meta_graph(\n",
    "            input_checkpoint + '.meta', clear_devices = clear_devices\n",
    "        )\n",
    "        saver.restore(sess, input_checkpoint)\n",
    "        output_graph_def = tf.graph_util.convert_variables_to_constants(\n",
    "            sess,\n",
    "            tf.get_default_graph().as_graph_def(),\n",
    "            output_node_names.split(','),\n",
    "        )\n",
    "        with tf.gfile.GFile(output_graph, 'wb') as f:\n",
    "            f.write(output_graph_def.SerializeToString())\n",
    "        print('%d ops in the final graph.' % len(output_graph_def.node))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from albert-tiny-toxic/model.ckpt\n",
      "WARNING:tensorflow:From <ipython-input-31-9a7215a4e58a>:23: convert_variables_to_constants (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.convert_variables_to_constants`\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/framework/graph_util_impl.py:277: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.extract_sub_graph`\n",
      "INFO:tensorflow:Froze 29 variables.\n",
      "INFO:tensorflow:Converted 29 variables to const ops.\n",
      "1226 ops in the final graph.\n"
     ]
    }
   ],
   "source": [
    "freeze_graph('albert-tiny-toxic', strings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "\n",
    "bucketName = 'huseinhouse-storage'\n",
    "Key = 'albert-tiny-toxic/frozen_model.pb'\n",
    "outPutname = \"v34/toxicity/albert-tiny-toxicity.pb\"\n",
    "\n",
    "s3 = boto3.client('s3')\n",
    "s3.upload_file(Key,bucketName,outPutname)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
